{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5w-QlZE9mvdY"
   },
   "source": [
    "# MNIST example with 2-conv. layer network\n",
    "\n",
    "This example demonstrates the usage of `FastaiLRFinder` with a 2-conv. layer network on the MNIST dataset.\n",
    "\n",
    "## Required Dependencies\n",
    "\n",
    "We assume that `torch` and `torchvision` are already installed. `ignite` can be installed using pip:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wdmoIYIvmvdp"
   },
   "outputs": [],
   "source": [
    "!pip install pytorch-ignite"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lP73mDV_nHV9"
   },
   "source": [
    "## Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lMphyBmmmvdw",
    "pycharm": {
     "is_executing": false
    }
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import copy\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.datasets import MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "z0epEdj4mvds"
   },
   "outputs": [],
   "source": [
    "import ignite\n",
    "ignite.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NEl3QFzZmvd1"
   },
   "outputs": [],
   "source": [
    "from ignite.engine import create_supervised_trainer, create_supervised_evaluator\n",
    "from ignite.metrics import Loss, Accuracy\n",
    "from ignite.contrib.handlers import ProgressBar\n",
    "from ignite.handlers import FastaiLRFinder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "L_wmAdFgmvdx"
   },
   "source": [
    "## Loading MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eZeKOgKymvdx"
   },
   "outputs": [],
   "source": [
    "# Directory the MNIST files are read from / downloaded to\n",
    "mnist_pwd = \"data\"\n",
    "# Mini-batch size for training; the test loader uses twice this value\n",
    "batch_size = 256"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Hgq0uZHAmvdy"
   },
   "outputs": [],
   "source": [
    "# Convert images to tensors, then normalize with mean 0.1307 and std 0.3081\n",
    "transform = transforms.Compose(\n",
    "    [\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.1307,), (0.3081,)),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Training split: shuffled mini-batches, loaded with 4 workers\n",
    "trainset = MNIST(mnist_pwd, train=True, download=True, transform=transform)\n",
    "trainloader = DataLoader(\n",
    "    trainset,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    num_workers=4,\n",
    ")\n",
    "\n",
    "# Test split: fixed order, twice the training batch size, loaded in the main process\n",
    "testset = MNIST(mnist_pwd, train=False, download=True, transform=transform)\n",
    "testloader = DataLoader(\n",
    "    testset,\n",
    "    batch_size=batch_size * 2,\n",
    "    shuffle=False,\n",
    "    num_workers=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JTLEOUqtmvdy"
   },
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6KqS5A4Umvd0",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"Small CNN classifier: two conv layers followed by two linear layers.\n",
    "\n",
    "    The forward pass flattens the conv features to 320 values per sample and\n",
    "    returns log-probabilities over 10 classes.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Convolutional feature extractor\n",
    "        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n",
    "        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n",
    "        self.conv2_drop = nn.Dropout2d()\n",
    "        # Fully connected classifier head\n",
    "        self.fc1 = nn.Linear(320, 50)\n",
    "        self.fc2 = nn.Linear(50, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # conv1 -> 2x2 max-pool -> ReLU\n",
    "        pooled = F.max_pool2d(self.conv1(x), 2)\n",
    "        x = F.relu(pooled)\n",
    "        # conv2 -> channel dropout -> 2x2 max-pool -> ReLU\n",
    "        dropped = self.conv2_drop(self.conv2(x))\n",
    "        x = F.relu(F.max_pool2d(dropped, 2))\n",
    "        # Flatten to 320 features per row for the linear layers\n",
    "        flat = x.view(-1, 320)\n",
    "        hidden = F.relu(self.fc1(flat))\n",
    "        # Functional dropout must be told explicitly whether we are training\n",
    "        hidden = F.dropout(hidden, training=self.training)\n",
    "        logits = self.fc2(hidden)\n",
    "        return F.log_softmax(logits, dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "U_EHmN2bmvd2"
   },
   "source": [
    "## Training loss (fastai)\n",
    "\n",
    "This learning rate range test follows the same procedure used by fastai. The model is trained for `num_iter` iterations while the learning rate is increased from its initial value, specified in the optimizer, to `end_lr`. The increase can be linear (`step_mode=\"linear\"`) or exponential (`step_mode=\"exp\"`); linear provides good results for small ranges, while exponential is recommended for larger ranges."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WxaXjwDhmvd2"
   },
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "criterion = nn.NLLLoss()\n",
    "model = Net()\n",
    "model.to(device)  # Move model before creating optimizer\n",
    "optimizer = optim.SGD(model.parameters(), lr=3e-4, momentum=0.9)\n",
    "\n",
    "# Save the initial state of the model and the optimizer to reset them later.\n",
    "# `state_dict()` returns references to the live tensors, which training updates\n",
    "# in place, so deep-copy them to keep a real snapshot of the initial values.\n",
    "init_model_state = copy.deepcopy(model.state_dict())\n",
    "init_opt_state = copy.deepcopy(optimizer.state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CTGJPVI6mvd2"
   },
   "outputs": [],
   "source": [
    "trainer = create_supervised_trainer(model, optimizer, criterion, device=device)\n",
    "\n",
    "# Report the trainer's per-iteration output in the progress bar as \"batch loss\"\n",
    "progress_bar = ProgressBar(persist=True)\n",
    "progress_bar.attach(trainer, output_transform=lambda loss: {\"batch loss\": loss})\n",
    "\n",
    "# Run the range test; the model and the optimizer are handed to the finder\n",
    "# so that their states can be restored once it finishes\n",
    "lr_finder = FastaiLRFinder()\n",
    "to_save = {\"model\": model, \"optimizer\": optimizer}\n",
    "with lr_finder.attach(trainer, to_save, diverge_th=1.5) as trainer_with_lr_finder:\n",
    "    trainer_with_lr_finder.run(trainloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's plot the results of the learning rate range test and print the suggested learning rate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oN0VkPapmvd5"
   },
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Draw the curve recorded by the finder during the range test\n",
    "ax = lr_finder.plot()\n",
    "plt.show()\n",
    "\n",
    "suggested_lr = lr_finder.lr_suggestion()\n",
    "print(\"Suggested LR\", suggested_lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Note that the model and the optimizer were restored to their initial states when `FastaiLRFinder` finished.*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EHHNoOi8mvd6"
   },
   "source": [
    "Let's compare training results:\n",
    "- a) using a fixed learning rate (`lr=3e-4`)\n",
    "- b) using the suggested (optimal) learning rate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: train with the fixed learning rate (lr=3e-4)\n",
    "trainer.run(trainloader, max_epochs=5)\n",
    "\n",
    "# Measure accuracy and NLL loss on the test set\n",
    "eval_metrics = {\"acc\": Accuracy(), \"loss\": Loss(nn.NLLLoss())}\n",
    "evaluator = create_supervised_evaluator(model, metrics=eval_metrics, device=device)\n",
    "evaluator.run(testloader)\n",
    "\n",
    "print(evaluator.state.metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After training, we need to get the model and the optimizer back to their initial state. To do that, we load the initial state again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the state dicts saved earlier in `init_model_state` / `init_opt_state`\n",
    "model.load_state_dict(init_model_state)\n",
    "optimizer.load_state_dict(init_opt_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NcT19wqkmvd6"
   },
   "source": [
    "Let's now apply the suggested learning rate to the optimizer and train the model again with it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the optimizer's learning rate to the finder's suggestion, then show it\n",
    "lr_finder.apply_suggested_lr(optimizer)\n",
    "print(optimizer.param_groups[0][\"lr\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DJqgyaFnmvd7"
   },
   "outputs": [],
   "source": [
    "# Same 5-epoch run as the baseline, now with the suggested learning rate\n",
    "trainer.run(trainloader, max_epochs=5)\n",
    "\n",
    "# Re-evaluate on the test set for comparison with the baseline metrics\n",
    "evaluator.run(testloader)\n",
    "print(evaluator.state.metrics)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "FastaiLRFinder_MNIST.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
